In [60]:
import pandas as pd
import numpy as np
import altair as alt
import theme
from natsort import natsorted, natsort_keygen
alt.data_transformers.disable_max_rows()
Out[60]:
DataTransformerRegistry.enable('default')
In [61]:
# Per-mutation effects on cell entry measured in the H3, H5, and H7 backgrounds,
# with structural annotations (RMSD between structures, per-structure RSA).
mutation_effects = pd.read_csv('../results/combined_effects/combined_mutation_effects.csv')
mutation_effects.head()
Out[61]:
| mutant | struct_site | h3_wt_aa | h5_wt_aa | h7_wt_aa | rmsd_h3h5 | rmsd_h3h7 | rmsd_h5h7 | 4o5n_aa_RSA | 4kwm_aa_RSA | 6ii9_aa_RSA | h3_effect | h3_effect_std | h5_effect | h5_effect_std | h7_effect | h7_effect_std | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | A | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | 0.0151 | 0.7225 | 0.0558 | 0.29180 | NaN | NaN |
| 1 | C | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | -0.4080 | 0.3850 | -0.4245 | 0.02737 | NaN | NaN |
| 2 | D | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | 0.2361 | 0.2740 | 0.2039 | 0.07884 | NaN | NaN |
| 3 | E | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | -0.2463 | 0.8478 | 0.1713 | 0.10210 | NaN | NaN |
| 4 | F | 9 | S | K | NaN | 9.1674 | NaN | NaN | 1.084277 | 1.140252 | NaN | 0.2061 | 0.3214 | -0.8397 | 1.34800 | NaN | NaN |
In [62]:
# Site-level effects: mean mutation effect per site (avg_*_effect) for each background,
# with the same structural annotations as the per-mutation table.
site_effects = pd.read_csv('../results/combined_effects/combined_site_effects.csv')
site_effects.head()
Out[62]:
| struct_site | h3_wt_aa | h5_wt_aa | h7_wt_aa | rmsd_h3h5 | rmsd_h3h7 | rmsd_h5h7 | 4o5n_aa_RSA | 4kwm_aa_RSA | 6ii9_aa_RSA | avg_h3_effect | avg_h5_effect | avg_h7_effect | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9 | S | K | NaN | 9.167400 | NaN | NaN | 1.084277 | 1.140252 | NaN | -0.050776 | -1.062932 | NaN |
| 1 | 10 | T | S | NaN | 8.157247 | NaN | NaN | 0.150962 | 0.175962 | NaN | -0.697911 | -3.224739 | NaN |
| 2 | 11 | A | D | D | 5.040040 | 2.984626 | 2.886615 | 0.050388 | 0.097927 | 0.624352 | -3.138280 | -3.921267 | -2.963026 |
| 3 | 12 | T | Q | K | 3.937602 | 1.626754 | 3.384350 | 0.268605 | 0.216889 | 0.368644 | -1.036219 | -0.467449 | -1.727747 |
| 4 | 13 | L | I | I | 3.687798 | 1.734039 | 2.549524 | 0.000000 | 0.000000 | 0.000000 | -3.941050 | -3.885729 | -3.840728 |
In [63]:
# Read in protein sequence identities
# (pairwise percent identity between the H3, H5, and H7 HA proteins).
seq_identity = pd.read_csv('../results/sequence_identity/ha_sequence_identity.csv')
seq_identity.head()
Out[63]:
| ha_x | ha_y | matches | alignable_residues | percent_identity | |
|---|---|---|---|---|---|
| 0 | H3 | H5 | 192.0 | 479.0 | 40.083507 |
| 1 | H3 | H7 | 229.0 | 483.0 | 47.412008 |
| 2 | H5 | H7 | 202.0 | 473.0 | 42.706131 |
In [64]:
def make_entry_scatter(ha_x, ha_y, x_title, y_title):
    """Scatter of per-mutation entry effects for one pair of HA backgrounds.

    Parameters
    ----------
    ha_x, ha_y : str
        Column prefixes ('h3', 'h5', 'h7') for the x and y axes.
    x_title, y_title : list of str
        Multi-line axis titles naming the cell line and background.

    Returns
    -------
    alt.Chart
    """
    return alt.Chart(mutation_effects).mark_circle(
        size=25, opacity=0.3, color='#767676'
    ).encode(
        x=alt.X(f'{ha_x}_effect', title=x_title),
        y=alt.Y(f'{ha_y}_effect', title=y_title),
        tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
                 f'{ha_x}_effect', f'{ha_y}_effect']
    ).properties(
        width=200,
        height=200,
        title=f'{ha_x.upper()} vs. {ha_y.upper()}'
    )

# One panel per pair of backgrounds (previously three copy-pasted chart blocks)
h3_h7_scatter = make_entry_scatter(
    'h3', 'h7',
    ['Effect on MDCK-SIAT1 entry', 'in H3 background'],
    ['Effect on 293-a2,6 entry', 'in H7 background'],
)
h3_h5_scatter = make_entry_scatter(
    'h3', 'h5',
    ['Effect on MDCK-SIAT1 entry', 'in H3 background'],
    ['Effect on 293T entry', 'in H5 background'],
)
h5_h7_scatter = make_entry_scatter(
    'h5', 'h7',
    ['Effect on 293T entry', 'in H5 background'],
    ['Effect on 293-a2,6 entry', 'in H7 background'],
)
h3_h7_scatter | h3_h5_scatter | h5_h7_scatter
Out[64]:
In [65]:
def scatter_and_density_plot(df, ha_x, ha_y, colors):
    """Scatter of mean per-site entry effects for two HA backgrounds with
    marginal densities and a correlation-coefficient annotation.

    Parameters
    ----------
    df : pd.DataFrame
        Site-level effects with ``avg_{ha}_effect`` and ``{ha}_wt_aa`` columns.
    ha_x, ha_y : str
        HA background names used as column prefixes (e.g. 'h3', 'h5').
    colors : dict
        Maps the labels 'Amino acid conserved' / 'Amino acid changed' to colors.

    Returns
    -------
    alt.VConcatChart
        x-axis marginal density on top; scatter (+ y=x line and r label)
        with the y-axis marginal density beside it.
    """
    # Correlation between the two backgrounds' mean site effects (pandas default: Pearson)
    r_value = df[f'avg_{ha_x}_effect'].corr(df[f'avg_{ha_y}_effect'])
    r_text = f"r = {r_value:.2f}"
    # Dashed y = x reference line over the plotted effect range
    identity_line = alt.Chart(pd.DataFrame({'x': [-5, 0.3], 'y': [-5, 0.3]})).mark_line(
        strokeDash=[6, 6],
        color='black'
    ).encode(
        x='x',
        y='y'
    )
    # Label each site by whether the wildtype amino acid is shared between backgrounds
    df = df.assign(
        same_wildtype=lambda x: np.where(
            x[f'{ha_x}_wt_aa'] == x[f'{ha_y}_wt_aa'],
            'Amino acid conserved',
            'Amino acid changed'
        ),
    )
    scatter = alt.Chart(df).mark_circle(
        size=35, opacity=1, stroke='black', strokeWidth=0.5
    ).encode(
        x=alt.X(f'avg_{ha_x}_effect', title=['Mean effect on cell entry', f'in {ha_x.upper()} background']),
        y=alt.Y(f'avg_{ha_y}_effect', title=['Mean effect on cell entry', f'in {ha_y.upper()} background']),
        color=alt.Color(
            'same_wildtype:N',
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
        tooltip=['struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', f'avg_{ha_x}_effect', f'avg_{ha_y}_effect']
    ).properties(
        width=175,
        height=175,
    )
    # Correlation coefficient rendered in the top-left corner (pixel coordinates)
    r_label = alt.Chart(pd.DataFrame({'text': [r_text]})).mark_text(
        align='left',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black'
    ).encode(
        text='text:N',
        x=alt.value(5),
        y=alt.value(5)
    )
    # Marginal density of the x-axis effects, one curve per wildtype class
    x_density = alt.Chart(df).transform_density(
        density=f'avg_{ha_x}_effect',
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[f'avg_{ha_x}_effect'].min(), df[f'avg_{ha_x}_effect'].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1).encode(
        alt.X('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.Y('density:Q', title='Density').stack(None),
        color=alt.Color(
            'same_wildtype:N',
            title=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
    ).properties(
        width=175,
        height=50
    )
    # Marginal density of the y-axis effects, drawn sideways next to the scatter
    y_density = alt.Chart(df).transform_density(
        density=f'avg_{ha_y}_effect',
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[f'avg_{ha_y}_effect'].min(), df[f'avg_{ha_y}_effect'].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1, orient='horizontal').encode(
        alt.Y('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.X('density:Q', title='Density').stack(None),
        color=alt.Color(
            'same_wildtype:N',
            title=None,
            scale=alt.Scale(domain=list(colors.keys()), range=list(colors.values())),
        ),
    ).properties(
        width=50,
        height=175
    )
    # Assemble: x density on top; scatter + reference line + r label, with y density beside
    marginal_plot = alt.vconcat(
        x_density,
        alt.hconcat(
            (scatter + identity_line + r_label),
            y_density
        )
    )
    return marginal_plot
# Shared color scheme: sites where the wildtype amino acid is conserved vs. changed
colors = {
    'Amino acid changed' : '#5484AF',
    'Amino acid conserved' : '#E04948'
}
# One scatter + marginal-density panel per pair of HA backgrounds
p1 = scatter_and_density_plot(site_effects, 'h3', 'h5', colors=colors)
p2 = scatter_and_density_plot(site_effects, 'h3', 'h7', colors=colors)
p3 = scatter_and_density_plot(site_effects, 'h5', 'h7', colors=colors)
p1 | p2 | p3
Out[65]:
Calculate Jensen-Shannon Divergence¶
In [66]:
def kl_divergence(p, q):
    """Kullback-Leibler divergence KL(p || q) for discrete distributions.

    Uses the convention 0 * log(0 / q) = 0, so zero-probability entries
    in ``p`` no longer produce NaN (previously ``p * np.log(p / q)`` was
    summed directly, giving ``0 * log(0) = nan``). For strictly positive
    ``p`` the result is identical to the original implementation.

    Parameters
    ----------
    p, q : array-like
        Probability vectors of equal length (q must be positive wherever
        p is positive).

    Returns
    -------
    float
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Only accumulate terms where p > 0; the p == 0 terms contribute 0 by convention.
    nonzero = p > 0
    return np.sum(p[nonzero] * np.log(p[nonzero] / q[nonzero]))
def compute_js_divergence_per_site(df, ha_x, ha_y, site_col="struct_site", min_mutations=15):
    """Compute JS divergence at each site and merge it back to the dataframe.

    Effects are exponentiated and normalized into per-site amino-acid
    "preference" distributions for each background; the Jensen-Shannon
    divergence between the two distributions is computed per site.
    Sites with fewer than ``min_mutations`` mutations measured in both
    backgrounds get NaN.
    """
    x_col, y_col = f'{ha_x}_effect', f'{ha_y}_effect'

    def site_jsd(group):
        """JSD between the two normalized effect vectors at one site, or NaN."""
        valid = group.dropna(subset=[x_col, y_col])
        if len(valid) < min_mutations:
            return np.nan
        p = np.exp(valid[x_col].values)
        q = np.exp(valid[y_col].values)
        p /= p.sum()
        q /= q.sum()
        m = 0.5 * (p + q)
        return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))

    js_per_site = {site: site_jsd(group) for site, group in df.groupby(site_col)}
    # Duplicate the per-site JSD across every row at the same site
    result = df.copy()
    result[f"JS_{ha_x}_vs_{ha_y}"] = result[site_col].map(js_per_site)
    return result
# Per-site divergence in amino-acid preferences for each pair of backgrounds;
# require at least 10 mutations measured in both backgrounds at a site.
js_df_h3_h7 = compute_js_divergence_per_site(mutation_effects, 'h3', 'h7', min_mutations=10)
js_df_h3_h5 = compute_js_divergence_per_site(mutation_effects, 'h3', 'h5', min_mutations=10)
js_df_h5_h7 = compute_js_divergence_per_site(mutation_effects, 'h5', 'h7', min_mutations=10)
Are epistatic shifts significant?¶
In [67]:
def compute_jsd_with_null(
    df,
    ha_x,
    ha_y,
    site_col="struct_site",
    min_mutations=15,
    n_bootstrap=1000,
    random_seed=42,
    jsd_threshold=0.02
):
    """
    Compute JS divergence with bootstrap null distribution for significance testing.

    The null distribution represents: "What JSD would I observe from measurement noise alone?"
    The null is generated by computing two separate null distributions:
    1. ha_x null: Sample ha_x twice with its measurement error, compute JSD
    2. ha_y null: Sample ha_y twice with its measurement error, compute JSD
    3. Take the mean of the two null distributions (balanced approach)

    This accounts for measurement noise from both experiments without assuming they have
    identical underlying effects. A significant result means the observed JSD is larger
    than what measurement noise alone could produce.

    Only sites with observed JSD > jsd_threshold are tested for significance.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with mutation effects and effect_std columns
    ha_x, ha_y : str
        HA subtype names (e.g., 'h3', 'h5')
    site_col : str
        Column name for site identifier
    min_mutations : int
        Minimum number of mutations required at a site
    n_bootstrap : int
        Number of bootstrap iterations
    random_seed : int
        Random seed for reproducibility
    jsd_threshold : float
        Minimum JSD value for a site to be tested for significance.
        Sites with observed JSD <= threshold will have p_value = NaN.
        Default is 0.02.

    Returns
    -------
    pd.DataFrame
        DataFrame with columns:
        - struct_site: site identifier
        - JS_observed: observed JSD value
        - JS_null_mean: mean of null distribution (NaN if below threshold)
        - JS_null_std: standard deviation of null distribution (NaN if below threshold)
        - p_value: empirical p-value (NaN if below threshold)
        - n_mutations: number of mutations at site
        - null_distribution: ndarray of bootstrap null JSDs (None if below threshold)
        Sorted by struct_site using natural sorting.
    """
    # Seed the global RNG so the bootstrap is reproducible across reruns.
    np.random.seed(random_seed)

    def compute_jsd_vectorized(effects, std, n_bootstrap):
        """Vectorized computation of null JSD distribution: resample the same
        background twice from its measurement error and compute the JSD between
        the two resamples, for all bootstrap replicates at once."""
        n_mutations = len(effects)
        # Generate all bootstrap samples at once: shape (n_bootstrap, n_mutations)
        effects_1 = np.random.normal(
            loc=effects[np.newaxis, :],  # broadcast to (1, n_mutations)
            scale=std[np.newaxis, :],  # broadcast to (1, n_mutations)
            size=(n_bootstrap, n_mutations)
        )
        effects_2 = np.random.normal(
            loc=effects[np.newaxis, :],
            scale=std[np.newaxis, :],
            size=(n_bootstrap, n_mutations)
        )
        # Compute probabilities for all bootstraps at once
        p1 = np.exp(effects_1)
        p2 = np.exp(effects_2)
        # Normalize: divide each row by its sum
        p1 = p1 / p1.sum(axis=1, keepdims=True)
        p2 = p2 / p2.sum(axis=1, keepdims=True)
        # Compute mixture distribution
        m = 0.5 * (p1 + p2)
        # Compute KL divergences (vectorized)
        # KL(p||m) = sum(p * log(p/m))
        kl_p_m = np.sum(p1 * np.log(p1 / m), axis=1)
        kl_q_m = np.sum(p2 * np.log(p2 / m), axis=1)
        # JSD = 0.5 * (KL(p||m) + KL(q||m))
        jsd = 0.5 * (kl_p_m + kl_q_m)
        return jsd

    results = []
    for site, group in df.groupby(site_col):
        # Filter to valid mutations with both effects and stds
        valid = group.dropna(subset=[
            f'{ha_x}_effect', f'{ha_y}_effect',
            f'{ha_x}_effect_std', f'{ha_y}_effect_std'
        ])
        if len(valid) < min_mutations:
            continue
        # Get observed effects
        effects_x = valid[f'{ha_x}_effect'].values
        effects_y = valid[f'{ha_y}_effect'].values
        # Get standard deviations
        std_x = valid[f'{ha_x}_effect_std'].values
        std_y = valid[f'{ha_y}_effect_std'].values
        # Compute observed JSD between ha_x and ha_y
        p_obs = np.exp(effects_x)
        q_obs = np.exp(effects_y)
        p_obs /= p_obs.sum()
        q_obs /= q_obs.sum()
        m_obs = 0.5 * (p_obs + q_obs)
        jsd_obs = 0.5 * (kl_divergence(p_obs, m_obs) + kl_divergence(q_obs, m_obs))
        # Only compute null distribution if JSD exceeds threshold
        if jsd_obs <= jsd_threshold:
            results.append({
                'struct_site': site,
                'JS_observed': jsd_obs,
                'JS_null_mean': np.nan,
                'JS_null_std': np.nan,
                'p_value': np.nan,
                'n_mutations': len(valid),
                'null_distribution': None
            })
            continue
        # Vectorized bootstrap null distributions
        jsd_null_x = compute_jsd_vectorized(effects_x, std_x, n_bootstrap)
        jsd_null_y = compute_jsd_vectorized(effects_y, std_y, n_bootstrap)
        # Take the mean of the two nulls (balanced approach)
        jsd_null = (jsd_null_x + jsd_null_y) / 2
        # Compute empirical p-value (one-tailed test: is observed JSD greater than null?)
        p_value = np.mean(jsd_null >= jsd_obs)
        results.append({
            'struct_site': site,
            'JS_observed': jsd_obs,
            'JS_null_mean': jsd_null.mean(),
            'JS_null_std': jsd_null.std(),
            'p_value': p_value,
            'n_mutations': len(valid),
            'null_distribution': jsd_null  # Store for visualization
        })
    # Convert to DataFrame and sort by struct_site using natural sorting
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('struct_site', key=natsort_keygen()).reset_index(drop=True)
    return results_df
In [68]:
# Compute JSD with null distributions for each comparison.
# All three comparisons use identical bootstrap settings.
jsd_null_kwargs = dict(min_mutations=10, n_bootstrap=1000, jsd_threshold=0.02)
jsd_with_pvals_h3_h5 = compute_jsd_with_null(js_df_h3_h5, 'h3', 'h5', **jsd_null_kwargs)
jsd_with_pvals_h3_h7 = compute_jsd_with_null(js_df_h3_h7, 'h3', 'h7', **jsd_null_kwargs)
jsd_with_pvals_h5_h7 = compute_jsd_with_null(js_df_h5_h7, 'h5', 'h7', **jsd_null_kwargs)
# Apply multiple testing correction (Benjamini-Hochberg FDR)
# Only apply FDR to sites that were tested (non-NaN p-values)
from scipy.stats import false_discovery_control
def apply_fdr_with_threshold(df):
"""Apply FDR correction only to non-NaN p-values."""
# Initialize q_value column with NaN
df['q_value'] = np.nan
# Get indices of non-NaN p-values
tested_mask = df['p_value'].notna()
if tested_mask.sum() > 0:
# Apply FDR correction only to tested sites
df.loc[tested_mask, 'q_value'] = false_discovery_control(df.loc[tested_mask, 'p_value'])
return df
# Apply BH FDR within each comparison; q_value stays NaN for untested sites.
jsd_with_pvals_h3_h5 = apply_fdr_with_threshold(jsd_with_pvals_h3_h5)
jsd_with_pvals_h3_h7 = apply_fdr_with_threshold(jsd_with_pvals_h3_h7)
jsd_with_pvals_h5_h7 = apply_fdr_with_threshold(jsd_with_pvals_h5_h7)
# Report significant sites as fractions (out of ALL sites with JSD measurements)
print("Significant sites (H3 vs H5, q < 0.1):")
total_h3h5 = len(jsd_with_pvals_h3_h5)
sig_h3h5 = (jsd_with_pvals_h3_h5['q_value'] < 0.1).sum()
print(f" {sig_h3h5} / {total_h3h5} sites ({sig_h3h5/total_h3h5:.2%})")
print("\nSignificant sites (H3 vs H7, q < 0.1):")
total_h3h7 = len(jsd_with_pvals_h3_h7)
sig_h3h7 = (jsd_with_pvals_h3_h7['q_value'] < 0.1).sum()
print(f" {sig_h3h7} / {total_h3h7} sites ({sig_h3h7/total_h3h7:.2%})")
print("\nSignificant sites (H5 vs H7, q < 0.1):")
total_h5h7 = len(jsd_with_pvals_h5_h7)
sig_h5h7 = (jsd_with_pvals_h5_h7['q_value'] < 0.1).sum()
print(f" {sig_h5h7} / {total_h5h7} sites ({sig_h5h7/total_h5h7:.2%})")
Significant sites (H3 vs H5, q < 0.1): 294 / 467 sites (62.96%) Significant sites (H3 vs H7, q < 0.1): 207 / 467 sites (44.33%) Significant sites (H5 vs H7, q < 0.1): 189 / 431 sites (43.85%)
In [69]:
def plot_jsd(df, jsd_pvals_df, ha_x, ha_y, identity_df=None, alpha=0.1):
    """
    Plot JSD values with significance coloring.

    Parameters
    ----------
    df : pd.DataFrame
        Main dataframe with mutation effects
    jsd_pvals_df : pd.DataFrame
        DataFrame with JSD p-values and q-values from compute_jsd_with_null
    ha_x, ha_y : str
        HA subtype names
    identity_df : pd.DataFrame, optional
        DataFrame with sequence identity information (adds %AAI to the title)
    alpha : float
        Significance threshold for q-value (default 0.1)
    """
    # Look up the pairwise percent amino-acid identity for the title, if provided
    if identity_df is not None:
        result = identity_df.query(
            f'ha_x=="{ha_x.upper()}" and ha_y=="{ha_y.upper()}"'
        )
        shared_aai = result['percent_identity'].values[0] if len(result) > 0 else None
    else:
        shared_aai = None
    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }
    # Work on a copy: the previous `df['struct_site'] = ...astype(str)` mutated
    # the caller's dataframe in place.
    df = df.assign(
        struct_site=lambda x: x['struct_site'].astype(str),
        mutant_type=lambda x: x['mutant'].map(amino_acid_classification),
    )
    # Merge significance data with site-level JSD data
    site_jsd_df = df[[
        'struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
        f'JS_{ha_x}_vs_{ha_y}', f'rmsd_{ha_x}{ha_y}'
    ]].dropna().drop_duplicates()
    # Merge q-values
    site_jsd_df = site_jsd_df.merge(
        jsd_pvals_df[['struct_site', 'q_value']],
        on='struct_site',
        how='left'
    )
    # Add significance flag
    site_jsd_df = site_jsd_df.assign(
        significant=lambda x: x['q_value'] < alpha
    )
    # Hover selection linking the JSD track to the per-site scatter below
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, nearest=True, fields=["struct_site"], value=1
    )
    sorted_sites = natsorted(df['struct_site'].unique())
    base = alt.Chart(site_jsd_df).encode(
        alt.X(
            "struct_site:O",
            sort=sorted_sites,
            title='Site',
            axis=alt.Axis(
                labelAngle=0,
                values=['1', '50', '100', '150', '200', '250', '300', '350', '400', '450', '500'],
                tickCount=11,
            )
        ),
        alt.Y(
            f'JS_{ha_x}_vs_{ha_y}:Q',
            title=['Divergence in amino-acid', 'preferences'],
            axis=alt.Axis(
                grid=False
            ),
            scale=alt.Scale(domain=[0, 0.7])
        ),
        tooltip=[
            'struct_site',
            f'{ha_x}_wt_aa',
            f'{ha_y}_wt_aa',
            alt.Tooltip(f'JS_{ha_x}_vs_{ha_y}', format='.4f'),
            alt.Tooltip(f'rmsd_{ha_x}{ha_y}', format='.2f'),
            alt.Tooltip('q_value', format='.4f'),
            'significant'
        ],
    ).properties(
        width=800,
        height=150
    )
    line = base.mark_line(opacity=0.5, stroke='#999999', size=1)
    # Points layer with conditional formatting based on hover and click
    points = base.mark_circle(filled=True).encode(
        size=alt.condition(
            variant_selector,
            alt.value(75),  # when selected
            alt.value(40)   # default
        ),
        color=alt.Color(
            'significant:N',
            title=['Significant', f'(FDR < {alpha})'],
            scale=alt.Scale(domain=[True, False], range=['#E15759', '#BAB0AC']),
            legend=alt.Legend(
                titleFontSize=14,
                labelFontSize=12
            )
        ),
        stroke=alt.condition(
            variant_selector,
            alt.value('black'),
            alt.value(None)
        ),
        strokeWidth=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0)
        ),
        opacity=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0.75)
        )
    ).add_params(
        variant_selector
    )
    # Correlation between cell entry effects plot
    # Filter based on hover (only if nothing clicked) or click
    base_corr_chart = (alt.Chart(df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect",
                title=["Effect on cell entry", f"in {ha_x.upper()}"],
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Y(
                f"{ha_y}_effect",
                title=["Effect on cell entry", f"in {ha_y.upper()}"],
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                title='Mutant type',
                scale=alt.Scale(
                    domain=['Aromatic', 'Hydrophilic', 'Hydrophobic','Negative', 'Positive', 'Special'],
                    range=["#4e79a7","#f28e2c","#e15759","#76b7b2","#59a14f","#edc949"]
                ),
                legend=alt.Legend(
                    titleFontSize=16,
                    labelFontSize=13
                )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
                     f'{ha_x}_effect', f'{ha_x}_effect_std',
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                     f'JS_{ha_x}_vs_{ha_y}'],
        )
        .transform_filter(
            variant_selector
        )
        .properties(
            height=150,
            width=150,
        )
    )
    # Vertical line at x = 0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(x='x:Q')
    # Horizontal line at y = 0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(y='y:Q')
    corr_chart = vline + hline + base_corr_chart
    # density plot
    density = alt.Chart(
        site_jsd_df
    ).transform_density(
        density=f'JS_{ha_x}_vs_{ha_y}',
        bandwidth=0.02,
        extent=[0,1],
        counts=True,
        steps=200
    ).mark_area(opacity=1, color='#CCEBC5', stroke='black', strokeWidth=1).encode(
        alt.X('value:Q', title=['Divergence in amino-acid', 'preferences']),
        alt.Y('density:Q', title='Density').stack(None),
    ).properties(
        width=200,
        height=60
    )
    if shared_aai is not None:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()} ({shared_aai:.1f}% AAI)'
    else:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()}'
    # combine the bar and heatmaps
    combined_chart = alt.vconcat(
        (line + points), corr_chart, density
    ).resolve_scale(
        y='independent',
        x='independent',
        color='independent'
    )
    combined_chart = combined_chart.properties(
        title=alt.Title(title_text,
            offset=0,
            fontSize=18,
            #subtitle=['Hover over sites to see mutation effects. Click to lock selection (double-click to clear).'],
            subtitleFontSize=16,
            anchor='middle'
        )
    )
    return combined_chart
# JSD track + linked per-site scatter + density for the H3 vs. H5 comparison
chart = plot_jsd(
    js_df_h3_h5,
    jsd_with_pvals_h3_h5,
    'h3', 'h5',
    seq_identity
)
chart.display()
In [70]:
# Same linked JSD visualization for H3 vs. H7
chart = plot_jsd(
    js_df_h3_h7,
    jsd_with_pvals_h3_h7,
    'h3', 'h7',
    seq_identity
)
chart.display()
In [71]:
# Same linked JSD visualization for H5 vs. H7
chart = plot_jsd(
    js_df_h5_h7,
    jsd_with_pvals_h5_h7,
    'h5', 'h7',
    seq_identity
)
chart.display()
In [72]:
# Export per-site divergence tables (one deduplicated row per site) to CSV.
divergence_exports = [
    (js_df_h3_h5,
     ['struct_site', 'h3_wt_aa', 'h5_wt_aa', 'rmsd_h3h5', '4o5n_aa_RSA', 'JS_h3_vs_h5'],
     '../results/divergence/h3_h5_divergence.csv'),
    (js_df_h3_h7,
     ['struct_site', 'h3_wt_aa', 'h7_wt_aa', 'rmsd_h3h7', '4o5n_aa_RSA', 'JS_h3_vs_h7'],
     '../results/divergence/h3_h7_divergence.csv'),
    (js_df_h5_h7,
     ['struct_site', 'h5_wt_aa', 'h7_wt_aa', 'rmsd_h5h7', '4o5n_aa_RSA', 'JS_h5_vs_h7'],
     '../results/divergence/h5_h7_divergence.csv'),
]
for divergence_df, export_cols, out_path in divergence_exports:
    divergence_df[export_cols].drop_duplicates().reset_index(drop=True).to_csv(
        out_path, index=False
    )
H7 a2,6 vs. H7 a2,3¶
In [73]:
def read_and_filter_data(
    path,
    effect_std_filter=2,
    times_seen_filter=2,
    n_selections_filter=2,
    clip_effect=-5
):
    """Read a cell-entry effects CSV, filter low-confidence mutations, and
    add wildtype identities with zero effect.

    Parameters
    ----------
    path : str
        CSV with site, wildtype, mutant, effect, effect_std, times_seen,
        and n_selections columns.
    effect_std_filter : float
        Keep mutations with effect_std <= this value.
    times_seen_filter : float
        Keep mutations with times_seen >= this value.
    n_selections_filter : float
        Keep mutations with n_selections >= this value.
    clip_effect : float
        Lower bound at which effect values are clipped.

    Returns
    -------
    pd.DataFrame
        Filtered effects (stop codons/indels removed, sites as strings),
        sorted by site and mutant.
    """
    print(f'Reading data from {path}')
    print(
        f"Filtering for:\n"
        f"  effect_std <= {effect_std_filter}\n"
        f"  times_seen >= {times_seen_filter}\n"
        f"  n_selections >= {n_selections_filter}"
    )
    print(f"Clipping effect values at {clip_effect}")
    df = pd.read_csv(path).query(
        'effect_std <= @effect_std_filter and \
        times_seen >= @times_seen_filter and \
        n_selections >= @n_selections_filter'
    ).query(
        'mutant not in ["*", "-"]'  # don't want stop codons/indels
    )
    df['site'] = df['site'].astype(str)
    # Bug fix: previously clipped at a hard-coded -5, silently ignoring the
    # clip_effect parameter (the message above already claimed to use it).
    df['effect'] = df['effect'].clip(lower=clip_effect)
    df = pd.concat([
        df,
        df[['site', 'wildtype']].drop_duplicates().assign(
            mutant=lambda x: x['wildtype'],
            effect=0.0,
            effect_std=0.0,
            times_seen=np.nan,
            n_selections=np.nan
        )  # add wildtype identities with zero effect at every site
    ], ignore_index=True).sort_values(['site', 'mutant']).reset_index(drop=True)
    return df
In [74]:
def load_entry_effects(path, prefix):
    """Load filtered entry effects and prefix the per-background columns."""
    return read_and_filter_data(path)[[
        'site', 'wildtype', 'mutant', 'effect', 'effect_std'
    ]].rename(
        columns={
            'site': 'struct_site',
            'wildtype': f'{prefix}_wt_aa',
            'effect': f'{prefix}_effect',
            'effect_std': f'{prefix}_effect_std',
        }
    )

# H7 entry effects measured in two cell lines (a2,3 vs. a2,6)
h7_23_df = load_entry_effects('../data/cell_entry_effects/293_2-3_entry_func_effects.csv', 'h7_2-3')
h7_26_df = load_entry_effects('../data/cell_entry_effects/293_2-6_entry_func_effects.csv', 'h7_2-6')
# Inner-merge on site/wildtype/mutant; RMSD is 0 since both use the same protein
h7_23_26_df = pd.merge(
    h7_23_df,
    h7_26_df,
    left_on=['struct_site', 'h7_2-3_wt_aa', 'mutant'],
    right_on=['struct_site', 'h7_2-6_wt_aa', 'mutant'],
).assign(
    **{'rmsd_h7_2-3h7_2-6': 0}
)
h7_23_26_df.head()
Reading data from ../data/cell_entry_effects/293_2-3_entry_func_effects.csv Filtering for: effect_std <= 2 times_seen >= 2 n_selections >= 2 Clipping effect values at -5 Reading data from ../data/cell_entry_effects/293_2-6_entry_func_effects.csv Filtering for: effect_std <= 2 times_seen >= 2 n_selections >= 2 Clipping effect values at -5
Out[74]:
| struct_site | h7_2-3_wt_aa | mutant | h7_2-3_effect | h7_2-3_effect_std | h7_2-6_wt_aa | h7_2-6_effect | h7_2-6_effect_std | rmsd_h7_2-3h7_2-6 | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 100 | G | A | -0.00205 | 0.9766 | G | -1.319 | 0.8159 | 0 |
| 1 | 100 | G | C | -3.91300 | 0.0130 | G | -4.422 | 0.0000 | 0 |
| 2 | 100 | G | D | -4.76800 | 0.0000 | G | -4.937 | 0.0000 | 0 |
| 3 | 100 | G | G | 0.00000 | 0.0000 | G | 0.000 | 0.0000 | 0 |
| 4 | 100 | G | H | -4.64600 | 0.0000 | G | -4.800 | 0.0000 | 0 |
In [75]:
# Per-site divergence between the two H7 selections (a2,3 vs. a2,6 cells)
js_df_h7_23_26 = compute_js_divergence_per_site(h7_23_26_df, 'h7_2-3', 'h7_2-6', min_mutations=10)
In [76]:
# Compute JSD with null distributions for each comparison
# (same bootstrap settings as the inter-subtype comparisons above)
jsd_with_pvals_h7_23_26 = compute_jsd_with_null(
    js_df_h7_23_26,
    'h7_2-3', 'h7_2-6',
    min_mutations=10,
    n_bootstrap=1000,
    jsd_threshold=0.02
)
jsd_with_pvals_h7_23_26 = apply_fdr_with_threshold(jsd_with_pvals_h7_23_26)
# Report significant sites as fractions (out of ALL sites with JSD measurements)
print("Significant sites (H7 2-3 vs H7 2-6, q < 0.1):")
total_h7_23_26 = len(jsd_with_pvals_h7_23_26)
sig_h7_23_26 = (jsd_with_pvals_h7_23_26['q_value'] < 0.1).sum()
print(f" {sig_h7_23_26} / {total_h7_23_26} sites ({sig_h7_23_26/total_h7_23_26:.2%})")
Significant sites (H7 2-3 vs H7 2-6, q < 0.1): 0 / 492 sites (0.00%)
In [77]:
# Linked JSD visualization for the two H7 selections
# (no identity_df passed, so the title shows no %AAI)
chart = plot_jsd(
    js_df_h7_23_26,
    jsd_with_pvals_h7_23_26,
    'h7_2-3', 'h7_2-6'
)
chart.display()
In [78]:
def plot_ridgeline_density(dfs_dict, x_col_template='JS_{ha_x}_vs_{ha_y}',
                           bandwidth=0.02, extent=[0,0.65],
                           colors=None, width=200, height=400,
                           overlap=2.5, label_mapping=None):
    """
    Plot ridgeline (joyplot) density plots for multiple dataframes.

    Parameters:
    -----------
    dfs_dict : dict
        Dictionary where keys are comparison labels (e.g., 'h3-h5', 'h3-h7')
        and values are tuples of (df, ha_x, ha_y)
        Example: {'h3-h5': (js_df_h3_h5, 'h3', 'h5'),
                  'h3-h7': (js_df_h3_h7, 'h3', 'h7')}
    x_col_template : str
        Template for column name with {ha_x} and {ha_y} placeholders
    bandwidth : float
        Bandwidth for density estimation
    extent : list
        [min, max] for density calculation
    colors : list or None
        List of colors for each comparison. If None, uses default color scheme
    width, height : int
        Dimensions of the plot
    overlap : float
        How much the ridges overlap (higher = more overlap)
    label_mapping : dict or None
        Optional mapping from dfs_dict keys to display labels
        (a list value renders as a multi-line ridge label)

    Returns:
    --------
    alt.Chart : Ridgeline density plot
    """
    # Default color scheme if none provided
    if colors is None:
        colors = ['#8DD3C7', '#FFFFB3', '#BEBADA', '#FB8072', '#80B1D3', '#FDB462']
    # Combine all dataframes with a comparison label
    combined_data = []
    for comparison, (df, ha_x, ha_y) in dfs_dict.items():
        col_name = x_col_template.format(ha_x=ha_x, ha_y=ha_y)
        temp_df = df[[col_name]].copy()
        temp_df['comparison'] = comparison
        temp_df['value'] = temp_df[col_name]
        combined_data.append(temp_df[['value', 'comparison']])
    combined_df = pd.concat(combined_data, ignore_index=True)
    if label_mapping is not None:
        combined_df['comparison'] = combined_df['comparison'].map(label_mapping)
    # Calculate step size for ridgeline spacing
    step = height / (len(dfs_dict) * overlap)
    # Create the ridgeline plot: one faceted density row per comparison, with
    # negative facet spacing so rows overlap like a joyplot.
    # (Removed a dead transform_calculate that computed an unused 'yvalue'
    # field, and redundant in-function imports of pandas/altair.)
    ridgeline = alt.Chart(combined_df).transform_density(
        density='value',
        bandwidth=bandwidth,
        extent=extent,
        groupby=['comparison'],
        steps=200
    ).mark_area(
        opacity=1,
        stroke='black',
        strokeWidth=1,
        interpolate='monotone'
    ).encode(
        alt.X('value:Q', title=['Divergence in amino-acid', 'preferences']),
        alt.Y('density:Q',
              title='Density',
              axis=None),
        alt.Row('comparison:N',
                title=None,
                header=alt.Header(labelAngle=0, labelAlign='left')),
        alt.Fill('comparison:N',
                 legend=None,
                 scale=alt.Scale(range=colors[:len(dfs_dict)]))
    ).properties(
        width=width,
        height=step,
        bounds='flush'
    ).configure_facet(
        spacing=-(step * (overlap - 1))
    ).configure_view(
        stroke=None
    ).configure_header(
        labelFontSize=14
    )
    return ridgeline
# Example usage: overlay the divergence distributions of all four comparisons
dfs_to_plot = {
    'h3-h5': (js_df_h3_h5, 'h3', 'h5'),
    'h3-h7': (js_df_h3_h7, 'h3', 'h7'),
    'h5-h7': (js_df_h5_h7, 'h5', 'h7'),
    'h7_2-3-h7_2-6': (js_df_h7_23_26, 'h7_2-3', 'h7_2-6')
}
plot_ridgeline_density(
    dfs_to_plot,
    # A list label renders as a two-line ridge header in Altair
    label_mapping={
        'h3-h5': 'H3 vs. H5',
        'h3-h7': 'H3 vs. H7',
        'h5-h7': 'H5 vs. H7',
        'h7_2-3-h7_2-6': ['H7 (a2,3) vs.', 'H7 (a2,6)']
    }
).display()
Examples of mutation effect correlations¶
In [79]:
def plot_correlation(df, ha_x, ha_y, site, decimal_places=2):
    """Scatter of per-mutation cell-entry effects at a single site for two
    HA backgrounds, with mutants drawn as letters colored by chemical class.

    Parameters
    ----------
    df : pd.DataFrame
        Mutation effects including a JS_{ha_x}_vs_{ha_y} divergence column.
    ha_x, ha_y : str
        HA background names used as column prefixes.
    site : str
        Site to plot (compared as a string against struct_site).
    decimal_places : int
        Decimal places for the divergence value shown in the title.

    Returns
    -------
    alt.LayerChart
        Scatter with dashed zero-effect reference lines and a
        'Site N / Divergence = x' title.
    """
    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }
    # Work on a filtered copy: the previous version cast struct_site in place,
    # mutating the caller's dataframe. The chart also re-queried the site a
    # second time; that redundant filter is removed.
    df = df.assign(
        struct_site=lambda x: x['struct_site'].astype(str),
        mutant_type=lambda x: x['mutant'].map(amino_acid_classification),
    ).query(f'struct_site == "{site}"')
    # All rows at a site carry the same duplicated JSD value
    jsd = df[f'JS_{ha_x}_vs_{ha_y}'].unique()[0]
    base_corr_chart = (alt.Chart(df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect",
                title=["Effect on cell entry", f"in {ha_x.upper()}"],
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Y(
                f"{ha_y}_effect",
                title=["Effect on cell entry", f"in {ha_y.upper()}"],
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                title='Mutant type',
                scale=alt.Scale(
                    domain=['Aromatic', 'Hydrophilic', 'Hydrophobic','Negative', 'Positive', 'Special'],
                    range=["#4e79a7","#f28e2c","#e15759","#76b7b2","#59a14f","#edc949"]
                ),
                legend=alt.Legend(
                    titleFontSize=16,
                    labelFontSize=13
                )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
                     f'{ha_x}_effect', f'{ha_x}_effect_std',
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                     f'JS_{ha_x}_vs_{ha_y}'],
        ).properties(
            height=125,
            width=125,
        )
    )
    # Vertical line at x = 0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(x='x:Q')
    # Horizontal line at y = 0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(y='y:Q')
    corr_chart = (vline + hline + base_corr_chart).properties(
        title=alt.Title([f'Site {site}', f'Divergence = {jsd:.{decimal_places}f}'],
            offset=0,
            fontSize=16,
            anchor='middle'
        )
    )
    return corr_chart
In [80]:
# Mutation-effect correlations at selected sites in the H3 vs. H5 comparison
(
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='86') |
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='97', decimal_places=4) |
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='198') |
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='241')
).display()
In [81]:
# Site 86 across all three pairwise subtype comparisons
(
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='86') |
    plot_correlation(js_df_h3_h7, 'h3', 'h7', site='86') |
    plot_correlation(js_df_h5_h7, 'h5', 'h7', site='86')
).display()
In [82]:
# Site 173 across all three pairwise subtype comparisons
(
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='173') |
    plot_correlation(js_df_h3_h7, 'h3', 'h7', site='173') |
    plot_correlation(js_df_h5_h7, 'h5', 'h7', site='173')
).display()
In [83]:
# Sites 178, 123, and 176 across all three pairwise subtype comparisons
(
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='178') |
    plot_correlation(js_df_h3_h7, 'h3', 'h7', site='178') |
    plot_correlation(js_df_h5_h7, 'h5', 'h7', site='178')
).display()
(
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='123') |
    plot_correlation(js_df_h3_h7, 'h3', 'h7', site='123') |
    plot_correlation(js_df_h5_h7, 'h5', 'h7', site='123')
).display()
(
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site='176') |
    plot_correlation(js_df_h3_h7, 'h3', 'h7', site='176') |
    plot_correlation(js_df_h5_h7, 'h5', 'h7', site='176')
).display()
# H3 forms H bonds at 178, 123, 176, and 211.
# H5 and H7 do not form any H bonds in this region, and therefore tolerate many more amino acids.